from __future__ import absolute_import

import sys
import torch
import torch.nn as nn
import torch.nn.init as init
from torch.autograd import Variable
from torch.nn import functional as F
import numpy as np
from pdb import set_trace as st
from skimage import color
from IPython import embed
from . import pretrained_networks as pn

from losses import masked_lpips as util


def spatial_average(in_tens, mask=None, keepdim=True):
    """Average ``in_tens`` over its spatial dims (dims 2 and 3).

    Without a mask this is a plain mean over dims 2 and 3. With a mask it is
    a mask-weighted mean, ``sum(in_tens * mask) / sum(mask)``. When the mask
    has a batch dim (same rank as ``in_tens``), the denominator is computed
    separately for every sample by summing over all of that sample's mask
    dims.

    Args:
        in_tens: tensor whose dims 2 and 3 are reduced.
        mask: optional weight tensor broadcastable against ``in_tens``.
        keepdim: keep the reduced dims as size-1 dims.

    Returns:
        The spatially averaged tensor. A sample whose mask is all zeros
        averages to 0 instead of NaN.
    """
    if mask is None:
        return in_tens.mean([2, 3], keepdim=keepdim)

    weighted_sum = (in_tens * mask).sum([2, 3], keepdim=keepdim)

    if mask.dim() == in_tens.dim():
        # Per-sample mask area. The former global ``torch.sum(mask)`` pooled
        # the areas of every sample in the batch, shrinking each result by
        # roughly the batch size.
        area = mask.reshape(mask.shape[0], -1).sum(dim=1)
        area = area.reshape((-1,) + (1,) * (weighted_sum.dim() - 1))
    else:
        # Mask without a batch dim: one area shared by every sample.
        area = mask.sum()

    area = area.to(weighted_sum.dtype)
    # Guard all-zero masks: 0 / eps == 0 instead of 0 / 0 == NaN.
    area = area.clamp_min(torch.finfo(area.dtype).eps)
    return weighted_sum / area


def upsample(in_tens, out_H=64):  # assumes scale factor is same for H and W
    """Bilinearly rescale ``in_tens`` so that its height targets ``out_H``.

    One ratio, ``out_H / H``, is applied to both spatial dims, so the width
    is scaled by the same factor as the height.
    """
    ratio = float(out_H) / in_tens.shape[2]
    return F.interpolate(
        in_tens, scale_factor=ratio, mode="bilinear", align_corners=False
    )


# Learned perceptual metric
class PNetLin(nn.Module):
    """LPIPS-style perceptual distance with optional spatial masking.

    Both inputs run through a feature backbone (``pn.vgg16``, ``pn.alexnet``
    or ``pn.squeezenet``). Each layer's features are normalized with
    ``util.normalize_tensor`` and their squared difference is either
    projected to one channel by a learned 1x1 conv (``lpips=True``) or
    summed over channels (``lpips=False``). The result is then spatially
    averaged (optionally under a mask) or, with ``spatial=True``, upsampled
    to the input height. The scores of the selected layers are summed.
    """

    def __init__(
        self,
        pnet_type="vgg",
        pnet_rand=False,
        pnet_tune=False,
        use_dropout=True,
        spatial=False,
        version="0.1",
        lpips=True,
        vgg_blocks=(1, 2, 3, 4, 5),
    ):
        """Build the backbone and, if ``lpips``, the per-layer linear heads.

        Args:
            pnet_type: "vgg"/"vgg16", "alex" or "squeeze".
            pnet_rand: backbone is built with ``pretrained=not pnet_rand``.
            pnet_tune: forwarded to the backbone as ``requires_grad``.
            use_dropout: put a dropout layer before each 1x1 conv head.
            spatial: return per-layer maps upsampled to the input height
                instead of spatial averages.
            version: "0.1" passes inputs through ``ScalingLayer`` first; any
                other value feeds them unscaled.
            lpips: use the learned 1x1 conv heads; otherwise sum the squared
                feature differences over channels.
            vgg_blocks: 1-based indices (ints or numeric strings) of the VGG
                layers whose scores are summed. Ignored for other backbones,
                which always use every layer.

        Raises:
            ValueError: if ``pnet_type`` is not a supported backbone.
        """
        super(PNetLin, self).__init__()

        self.pnet_type = pnet_type
        self.pnet_tune = pnet_tune
        self.pnet_rand = pnet_rand
        self.spatial = spatial
        self.lpips = lpips
        self.version = version
        self.scaling_layer = ScalingLayer()

        is_vgg = self.pnet_type in ["vgg", "vgg16"]
        if is_vgg:
            net_type = pn.vgg16
            self.chns = [64, 128, 256, 512, 512]
        elif self.pnet_type == "alex":
            net_type = pn.alexnet
            self.chns = [64, 192, 384, 256, 256]
        elif self.pnet_type == "squeeze":
            net_type = pn.squeezenet
            self.chns = [64, 128, 256, 384, 384, 512, 512]
        else:
            raise ValueError("Unsupported pnet_type: %r" % (pnet_type,))
        self.L = len(self.chns)

        # Layer selection is stored as strings because callers may pass ints
        # or strings (e.g. parsed from a command line); forward() compares
        # against str(layer_index). Previously the int default never matched
        # and the loss was always 0.0, and non-VGG backbones had no
        # self.blocks at all.
        if is_vgg:
            self.blocks = [str(b) for b in vgg_blocks]
        else:
            self.blocks = [str(kk + 1) for kk in range(self.L)]

        self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)

        if lpips:
            # Heads are registered as individual attributes (lin0..lin6),
            # which fixes their state_dict keys; self.lins is a plain list
            # used only for indexing.
            self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
            self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
            self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
            self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
            self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
            self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
            if self.pnet_type == "squeeze":  # 7 layers for squeezenet
                self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
                self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
                self.lins += [self.lin5, self.lin6]

    def _multiscale_masks(self, mask, outs):
        """Resize ``mask`` (nearest) to the spatial size of each of the L feature maps."""
        # Add leading singleton dims until the mask is 4-D (N, C, H, W), as
        # F.interpolate expects; a 3-D mask gains one, as before.
        while mask.dim() < 4:
            mask = mask.unsqueeze(0)
        # F.interpolate needs floating-point input (e.g. for bool masks).
        if not torch.is_floating_point(mask):
            mask = mask.float()
        # Resize from the original mask each time rather than chaining
        # resizes of the previous level, so nearest-neighbour rounding does
        # not compound across levels.
        return [
            F.interpolate(mask, size=tuple(outs[kk].shape[2:]), mode="nearest")
            for kk in range(self.L)
        ]

    def forward(self, in0, in1, mask=None, retPerLayer=False):
        """Compute the perceptual distance between ``in0`` and ``in1``.

        Args:
            in0, in1: image batches fed to the backbone.
            mask: optional spatial mask, resized to every feature resolution
                and used to restrict the spatial average. Ignored when
                ``spatial=True``.
            retPerLayer: also return the list of per-layer scores (all
                layers, selected or not).

        Returns:
            The sum of the selected layers' scores (0.0 if none are
            selected), or ``(total, per_layer_scores)`` if ``retPerLayer``.
        """
        # v0.0 - original release had a bug, where input was not scaled
        if self.version == "0.1":
            in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1)
        else:
            in0_input, in1_input = in0, in1
        outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)

        if mask is not None:
            masks = self._multiscale_masks(mask, outs0)
        else:
            masks = [None] * self.L

        diffs = []
        for kk in range(self.L):
            feat0 = util.normalize_tensor(outs0[kk])
            feat1 = util.normalize_tensor(outs1[kk])
            diffs.append((feat0 - feat1) ** 2)

        if self.lpips:
            # Each head is a 1x1 conv that collapses the channels to one.
            layer_maps = [self.lins[kk].model(diffs[kk]) for kk in range(self.L)]
        else:
            layer_maps = [diff.sum(dim=1, keepdim=True) for diff in diffs]

        if self.spatial:
            res = [upsample(lm, out_H=in0.shape[2]) for lm in layer_maps]
        else:
            # With a mask, only pixels inside it contribute to the mean
            # (now in both the lpips and non-lpips paths).
            res = [
                spatial_average(lm, mask=masks[kk], keepdim=True)
                for kk, lm in enumerate(layer_maps)
            ]

        # Start from 0.0 so the first selected res entry is not mutated
        # in place by the subsequent ``+=``.
        val = 0.0
        for kk in range(self.L):
            if str(kk + 1) in self.blocks:  # blocks are 1-based strings
                val += res[kk]

        if retPerLayer:
            return (val, res)
        return val


class ScalingLayer(nn.Module):
    """Normalize 3-channel input with a fixed per-channel shift and scale.

    Computes ``(inp - shift) / scale``. Both constants are registered as
    buffers of shape (1, 3, 1, 1), so they broadcast over the channel dim
    of a 4-D batch and are not trainable parameters.
    """

    def __init__(self):
        super(ScalingLayer, self).__init__()
        shift = torch.Tensor([-0.030, -0.088, -0.188]).view(1, 3, 1, 1)
        scale = torch.Tensor([0.458, 0.448, 0.450]).view(1, 3, 1, 1)
        self.register_buffer("shift", shift)
        self.register_buffer("scale", scale)

    def forward(self, inp):
        centered = inp - self.shift
        return centered / self.scale


class NetLinLayer(nn.Module):
    """A single linear layer which does a 1x1 conv.

    ``self.model`` is an ``nn.Sequential`` holding an optional ``nn.Dropout``
    followed by a bias-free 1x1 ``nn.Conv2d`` from ``chn_in`` to ``chn_out``
    channels.
    """

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()

        stages = []
        if use_dropout:
            stages.append(nn.Dropout())
        stages.append(nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False))
        self.model = nn.Sequential(*stages)


class Dist2LogitLayer(nn.Module):
    """ takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) """

    def __init__(self, chn_mid=32, use_sigmoid=True):
        """Build three 1x1 convs (5 -> chn_mid -> chn_mid -> 1) with
        LeakyReLU(0.2) between them, plus a trailing Sigmoid if requested."""
        super(Dist2LogitLayer, self).__init__()

        def pointwise(c_in, c_out):
            # 1x1 convolution with bias, i.e. a per-pixel fully connected layer.
            return nn.Conv2d(c_in, c_out, 1, stride=1, padding=0, bias=True)

        stack = [
            pointwise(5, chn_mid),
            nn.LeakyReLU(0.2, True),
            pointwise(chn_mid, chn_mid),
            nn.LeakyReLU(0.2, True),
            pointwise(chn_mid, 1),
        ]
        if use_sigmoid:
            stack.append(nn.Sigmoid())
        self.model = nn.Sequential(*stack)

    def forward(self, d0, d1, eps=0.1):
        # Five input channels: both distances, their difference, and both
        # ratios with ``eps`` added to the denominator.
        features = [d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)]
        return self.model.forward(torch.cat(features, dim=1))


class BCERankingLoss(nn.Module):
    """Binary cross-entropy between a ``Dist2LogitLayer`` output and a target.

    The target is ``(judge + 1) / 2``, so a ``judge`` in [-1, 1] maps to
    [0, 1] (that ``judge`` lies in this range is an assumption; verify
    against callers).
    """

    def __init__(self, chn_mid=32):
        super(BCERankingLoss, self).__init__()
        self.net = Dist2LogitLayer(chn_mid=chn_mid)
        self.loss = torch.nn.BCELoss()

    def forward(self, d0, d1, judge):
        target = (judge + 1.0) * 0.5
        # The prediction is kept on the instance as ``self.logit``
        # (presumably for later inspection by the caller).
        self.logit = self.net.forward(d0, d1)
        return self.loss(self.logit, target)


# L2, DSSIM metrics
class FakeNet(nn.Module):
    """Parameter-free base module for the non-learned metrics.

    Stores only configuration read by subclasses: ``use_gpu`` (whether the
    result tensor is moved with ``.cuda()``) and ``colorspace`` (which
    comparison path to take).
    """

    def __init__(self, use_gpu=True, colorspace="Lab"):
        super(FakeNet, self).__init__()
        self.use_gpu, self.colorspace = use_gpu, colorspace


class L2(FakeNet):
    """Mean squared error between two batch-size-1 image tensors.

    In "RGB" mode the squared difference is averaged directly on the input
    tensors. In "Lab" mode both inputs are converted with
    ``util.tensor2tensorlab`` / ``util.tensor2np`` and compared by
    ``util.l2`` with ``range=100.0``.
    """

    def forward(self, in0, in1, retPerLayer=None):
        """Return the L2 distance as a 1-element tensor.

        Args:
            in0, in1: image tensors with batch size 1.
            retPerLayer: accepted but ignored.

        Raises:
            AssertionError: if the batch size is not 1.
            ValueError: if ``self.colorspace`` is neither "RGB" nor "Lab"
                (previously this silently returned None).
        """
        assert in0.size()[0] == 1  # currently only supports batchSize 1

        if self.colorspace == "RGB":
            # One mean over channels and both spatial dims -> shape (N,).
            return ((in0 - in1) ** 2).mean(dim=[1, 2, 3])
        if self.colorspace == "Lab":
            value = util.l2(
                util.tensor2np(util.tensor2tensorlab(in0.data, to_norm=False)),
                util.tensor2np(util.tensor2tensorlab(in1.data, to_norm=False)),
                range=100.0,
            ).astype("float")
            ret_var = Variable(torch.Tensor((value,)))
            if self.use_gpu:
                ret_var = ret_var.cuda()
            return ret_var
        raise ValueError(
            "Unsupported colorspace %r; expected 'RGB' or 'Lab'" % (self.colorspace,)
        )


class DSSIM(FakeNet):
    """DSSIM (via ``util.dssim``) between two batch-size-1 image tensors.

    "RGB" mode compares ``util.tensor2im`` conversions with ``range=255.0``;
    "Lab" mode compares ``util.tensor2tensorlab`` / ``util.tensor2np``
    conversions with ``range=100.0``.
    """

    def forward(self, in0, in1, retPerLayer=None):
        """Return the DSSIM value as a 1-element tensor.

        Args:
            in0, in1: image tensors with batch size 1.
            retPerLayer: accepted but ignored.

        Raises:
            AssertionError: if the batch size is not 1.
            ValueError: if ``self.colorspace`` is neither "RGB" nor "Lab"
                (previously an UnboundLocalError on ``value``).
        """
        assert in0.size()[0] == 1  # currently only supports batchSize 1

        if self.colorspace == "RGB":
            value = util.dssim(
                1.0 * util.tensor2im(in0.data),
                1.0 * util.tensor2im(in1.data),
                range=255.0,
            ).astype("float")
        elif self.colorspace == "Lab":
            value = util.dssim(
                util.tensor2np(util.tensor2tensorlab(in0.data, to_norm=False)),
                util.tensor2np(util.tensor2tensorlab(in1.data, to_norm=False)),
                range=100.0,
            ).astype("float")
        else:
            raise ValueError(
                "Unsupported colorspace %r; expected 'RGB' or 'Lab'" % (self.colorspace,)
            )
        ret_var = Variable(torch.Tensor((value,)))
        if self.use_gpu:
            ret_var = ret_var.cuda()
        return ret_var


def print_network(net):
    """Print ``net``'s repr, then the total element count of its parameters."""
    total = sum(p.numel() for p in net.parameters())
    print("Network", net)
    print("Total number of parameters: %d" % total)
